#! /usr/bin/env python
# -*- coding: utf-8 -*-#
'''
# @Time : 2023/10/16 23:29
# @Author : Shiyu He
# @University : Xinjiang University
'''
import csv

import textgrid
import json
import wave
import copy

def exchange_error_mark(error_mark):
    """Translate an alignment error mark into its one-letter tier code.

    "换" -> "s", "插" -> "a", "删" -> "d"; any other value yields "c".
    """
    # Equality comparison (not hashing) so that any object may be passed in.
    mark_codes = (("换", "s"), ("插", "a"), ("删", "d"))
    for symbol, code in mark_codes:
        if error_mark == symbol:
            return code
    return "c"


def insert_to_every_interval(tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone, label, mark, rec, pinyin, start, end):
    """Add one interval spanning ``[start, end]`` to each of the five tiers.

    The label and rec tiers receive ``label`` and ``rec`` unchanged, the
    error-type tier receives ``exchange_error_mark(mark)``, the pinyin tier
    receives the items of ``pinyin`` joined with "," (leading/trailing
    square brackets stripped, single quotes removed) and the phone tier
    receives an empty mark.
    """
    print("duration:", end)

    def _add(tier, text):
        # Every tier gets its own Interval object over the same time span.
        tier.addInterval(textgrid.Interval(minTime=start, maxTime=end, mark=text))

    _add(tier_label, label)
    _add(tier_mark, exchange_error_mark(mark))
    _add(tier_rec, rec)
    joined_pinyin = ','.join(pinyin)
    pinyin_text = joined_pinyin.strip("[]").replace("'", "")
    _add(tier_pinyin, pinyin_text)
    _add(tier_phone, "")


def insert_sil_to_interval(tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone, start, end):
    """Add a "sil" interval spanning ``[start, end]`` to each of the five tiers."""
    every_tier = (tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone)
    for tier in every_tier:
        silence = textgrid.Interval(minTime=start, maxTime=end, mark="sil")
        tier.addInterval(silence)


def check_and_change_every_tier_maxTime(tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone, end, duration):
    """Stretch every tier's ``maxTime`` to ``end`` when ``end`` exceeds ``duration``.

    Returns the effective duration afterwards: ``end`` if the tiers were
    stretched, otherwise ``duration`` unchanged (the tiers are then left
    untouched).
    """
    if not (end > duration):
        return duration
    # The new interval runs past the current limit: raise it on all tiers.
    for each_tier in (tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone):
        each_tier.maxTime = end
    return end

def process_raw_textgrid(audio_name, duration, aligned_dict, timestamp_dict):
    """Build the five annotation tiers (label, errortype, rec, pinyin, phone).

    The alignment entries of ``audio_name`` are walked in order:

    * an entry marked "删" or "插" has no timestamp of its own, so it gets a
      short synthetic interval placed directly after the previous interval;
    * every other entry consumes the next ``[start, end]`` pair from the
      timestamp list.

    Gaps between consecutive intervals and the tail up to the final duration
    are filled with "sil" intervals. When a synthetic interval pushes past
    the current duration, the ``maxTime`` of every tier is stretched.

    Args:
        audio_name: key of the utterance in both dictionaries.
        duration: initial ``maxTime`` of every tier.
        aligned_dict: ``aligned_dict[audio_name]`` must provide the parallel
            lists "error_mark", "lab", "rec", "py_label" and "py_rec".
        timestamp_dict: ``timestamp_dict[audio_name]["time_stamp"]`` is a
            list whose items are indexed as ``[0]`` = start, ``[1]`` = end
            (presumably seconds, like ``duration`` -- confirm with producer).

    Returns:
        The tuple ``(tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone)``.
    """
    # Hand-tuned length of the synthetic interval inserted for "删"/"插"
    # entries; a larger value disturbs the timestamps that follow.
    inserted_span = 0.02
    # Hand-tuned length used when the shifted start reaches or passes the
    # recorded end of a timestamp; should stay small as well.
    fallback_span = 0.15

    aligned = aligned_dict[audio_name]
    error_marks = aligned["error_mark"]
    time_stamps = timestamp_dict[audio_name]["time_stamp"]

    # One tier per annotation layer (built once; five tiers -> grid size 5).
    tier_label = textgrid.IntervalTier(name="label", minTime=0., maxTime=duration)
    tier_mark = textgrid.IntervalTier(name="errortype", minTime=0., maxTime=duration)
    tier_rec = textgrid.IntervalTier(name="rec", minTime=0., maxTime=duration)
    tier_pinyin = textgrid.IntervalTier(name="pinyin", minTime=0., maxTime=duration)
    tier_phone = textgrid.IntervalTier(name="phone", minTime=0., maxTime=duration)
    tiers = (tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone)

    previous_end = 0.0
    # timestamp_index only advances for entries that own a timestamp, so it
    # cannot run ahead of the timestamp list because of "删"/"插" entries;
    # align_index advances for every alignment entry.
    timestamp_index = 0
    try:
        for align_index, error_mark in enumerate(error_marks):
            lab = aligned["lab"][align_index]
            rec = aligned["rec"][align_index]
            has_timestamp = error_mark not in ("删", "插")

            # Entries that consume a timestamp are logged with a leading "*".
            prefix = "*" if has_timestamp else ""
            print("{}{} 第 {} lab: {}".format(prefix, audio_name, align_index, lab))
            print("{}{} 第 {} rec: {}".format(prefix, audio_name, align_index, rec))
            print("{}{} 第 {} error_mark: {}".format(prefix, audio_name, align_index, error_mark))
            if has_timestamp:
                stamp = time_stamps[timestamp_index]
                print("*{} 第 {} time_stamp: {}".format(audio_name, timestamp_index, stamp))
            print()

            py_mark = [aligned["py_label"][align_index], aligned["py_rec"][align_index]]

            gap_start = None
            if has_timestamp:
                start, end = stamp[0], stamp[1]
                if start < previous_end:
                    # A synthetic interval spilled into this timestamp:
                    # move the start back to where the previous one ended.
                    start = previous_end
                elif start > previous_end:
                    # Unlabelled gap before this timestamp: fill with "sil".
                    gap_start = previous_end
                if end <= start:
                    # The shifted start reached or passed the recorded end.
                    end = start + fallback_span
            else:
                start = previous_end
                end = start + inserted_span

            # Stretch the tiers before adding anything, so that neither the
            # gap filler nor the interval itself ends past maxTime
            # (addInterval is expected to reject such intervals).
            duration = check_and_change_every_tier_maxTime(
                tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone, end, duration)
            if gap_start is not None:
                insert_sil_to_interval(
                    tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone, gap_start, start)
            insert_to_every_interval(
                tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone,
                lab, error_mark, rec, py_mark, start, end)

            previous_end = end
            if has_timestamp:
                timestamp_index += 1
    except IndexError:
        # Best effort: the timestamp list (or one of the parallel alignment
        # lists) is shorter than "error_mark"; keep what was built so far and
        # drop the remaining entries.
        print("Index out of range error occurred. Skipping this iteration.")

    if previous_end < duration:
        # Close every tier with a trailing "sil" up to the final duration.
        insert_sil_to_interval(tier_label, tier_mark, tier_rec, tier_pinyin, tier_phone, previous_end, duration)
    return tiers


def generate_textgrid(audio_name, aligned_file_json, time_stamp_json, output_dir=None):
    """Build and write ``<audio_name>.TextGrid`` from the two JSON files.

    Args:
        audio_name: key of the utterance in both JSON files; also the stem of
            the output file name.
        aligned_file_json: path of the alignment JSON (UTF-8).
        time_stamp_json: path of the timestamp JSON (UTF-8); its entry for
            ``audio_name`` provides "duration" and "time_stamp".
        output_dir: directory to write the TextGrid into. When omitted, the
            module-level ``output_tg_path`` is used, as before.

    A message is printed stating whether the number of timestamps matches
    ``all + ins - del`` from the alignment entry; the file is written in
    either case.
    """
    import os  # local import keeps this function self-contained

    if output_dir is None:
        output_dir = output_tg_path

    with open(aligned_file_json, 'r', encoding='utf-8') as f_align:
        aligned_dict = json.load(f_align)
    with open(time_stamp_json, 'r', encoding='utf-8') as f_timestamp:
        timestamp_dict = json.load(f_timestamp)

    aligned = aligned_dict[audio_name]
    time_stamps = timestamp_dict[audio_name]["time_stamp"]
    print("{}  error_mark_len: {}".format(audio_name, len(aligned["error_mark"])))
    print("{}  time_stamp_len: {}".format(audio_name, len(time_stamps)))

    duration = timestamp_dict[audio_name]["duration"]
    tg = textgrid.TextGrid(minTime=0, maxTime=duration)
    tiers = process_raw_textgrid(audio_name, duration, aligned_dict, timestamp_dict)
    # Attach the tiers to the TextGrid in their fixed order.
    tg.tiers.extend(tiers)
    # process_raw_textgrid may have stretched the tiers beyond ``duration``;
    # keep the TextGrid's own maxTime in step so it is never smaller than a
    # tier's maxTime in the written file.
    tg.maxTime = max(tier.maxTime for tier in tiers)

    expected_stamps = aligned["all"] + aligned["ins"] - aligned["del"]
    # NOTE: incomplete grids were once meant to be saved with a leading "#"
    # in the file name; the upload format requires the plain name, so both
    # cases share the same path.
    save_path = os.path.join(output_dir, str(audio_name) + ".TextGrid")
    if expected_stamps == len(time_stamps):
        print(str(audio_name) + ".TextGrid已成功生成！")
    else:
        print(str(audio_name) + ".TextGrid已生成但缺少 {} 个时间戳！".format(expected_stamps - len(time_stamps)))
    print()
    tg.write(save_path)


if __name__ == '__main__':
    # Alternative relative input locations:
    # align_path = 'result/alignment.json'
    # timastamp_path = 'result/timestamp.json'
    align_path = r'E:\正式标注数据\已选数据\textgrid生成所需文件\alignment.json'
    timastamp_path = r'E:\正式标注数据\已选数据\textgrid生成所需文件\timestamp.json'
    scp_label = r'E:\正式标注数据\已选数据\textgrid生成所需文件\label.txt'
    # Read as a module-level global by generate_textgrid when saving.
    output_tg_path = r'E:\正式标注数据\已选数据\wer0-50andtext10-15-tg'
    with open(scp_label, 'r', encoding='utf-8') as f_scp:
        for line in f_scp:
            # The audio name is the first whitespace-separated token;
            # blank lines produce no tokens and are skipped.
            fields = line.split()
            if not fields:
                continue
            generate_textgrid(fields[0], align_path, timastamp_path)
